{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning: algorithms, objectives, and assumptions\n",
    "(c) Deniz Yuret 2019\n",
    "\n",
    "In this notebook we will analyze three classic learning algorithms.\n",
    "* **Perceptron:** ([Rosenblatt, 1957](https://en.wikipedia.org/wiki/Perceptron)) a neuron model trained with a simple algorithm that updates model weights using the input when the prediction is wrong.\n",
    "* **Adaline:** ([Widrow and Hoff, 1960](https://en.wikipedia.org/wiki/ADALINE)) a neuron model trained with a simple algorithm that updates model weights using the error multiplied by the input (aka least mean squares (LMS), the delta learning rule, or the Widrow-Hoff rule).\n",
    "* **Softmax classification:** ([Cox, 1958](https://en.wikipedia.org/wiki/Multinomial_logistic_regression)) a multiclass generalization of the logistic regression model from statistics (aka multinomial LR, softmax regression, maxent classifier, etc.).\n",
    "\n",
    "We will look at these learners from three different perspectives:\n",
    "* **Algorithm:** First we ask only **how** the learner works, i.e. how it changes after observing each example.\n",
    "* **Objectives:** Next we ask **what** objective guides the algorithm, whether it is optimizing a particular objective function, and whether we can use a generic *optimization algorithm* instead.\n",
    "* **Assumptions:** Finally we ask **why** we think this algorithm makes sense, what prior assumptions it implies, and whether we can use *probabilistic inference* for optimal learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "72"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using Knet, Plots, Statistics, LinearAlgebra, Random\n",
    "Base.argmax(a::KnetArray) = argmax(Array(a))\n",
    "Base.argmax(a::AutoGrad.Value) = argmax(value(a))\n",
    "ENV[\"COLUMNS\"] = 72"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Loading MNIST...\n",
      "└ @ Main /kuacc/users/dyuret/.julia/packages/Knet/LjPts/data/mnist.jl:33\n"
     ]
    }
   ],
   "source": [
    "include(Knet.dir(\"data/mnist.jl\"))\n",
    "xtrn, ytrn, xtst, ytst = mnist()\n",
    "ARRAY = Array{Float32}\n",
    "xtrn, xtst = ARRAY(mat(xtrn)), ARRAY(mat(xtst))\n",
    "# One-hot encode integer labels: column i has a 1 in row y[i] and 0 elsewhere\n",
    "onehot(y) = (m=ARRAY(zeros(maximum(y),length(y))); for i in 1:length(y); m[y[i],i]=1; end; m)\n",
    "ytrn, ytst = onehot(ytrn), onehot(ytst);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "784×60000 Array{Float32,2}\n",
      "10×60000 Array{Float32,2}\n",
      "784×10000 Array{Float32,2}\n",
      "10×10000 Array{Float32,2}\n"
     ]
    }
   ],
   "source": [
    "println.(summary.((xtrn, ytrn, xtst, ytst)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 10000, 784, 10)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NTRN,NTST,XDIM,YDIM = size(xtrn,2), size(xtst,2), size(xtrn,1), size(ytrn,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10×784 Array{Float32,2}:\n",
       " -0.723526    0.984467   1.53663   …  -1.02691    0.827406 \n",
       " -1.23713    -0.130535  -0.426584     -2.98049    0.236315 \n",
       " -1.01081     0.599193   0.39188      -0.254158  -0.0685053\n",
       "  1.95004     1.12775   -0.270512      1.50776    0.38962  \n",
       " -1.65777    -1.30566    1.49927      -0.288618   1.3921   \n",
       "  0.247336    0.37563    1.53998   …  -0.227825   0.81309  \n",
       " -0.0449449  -0.775154   0.535788      0.749737  -0.534139 \n",
       "  0.254892   -0.444621   0.640151      0.130355   0.64397  \n",
       " -1.21859    -0.588326   1.21659       0.408257   0.723103 \n",
       "  0.557634    0.340867   0.159563      0.540516   0.63455  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model weights\n",
    "w = ARRAY(randn(YDIM,XDIM))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10×60000 Array{Float32,2}:\n",
       "   4.96523  -7.80919  -24.0313    -0.364504  …  -10.8263    -9.00563 \n",
       "  13.2433    3.92816    1.89772   -9.89675       -4.45021  -15.0672  \n",
       " -12.9268    2.97658   -7.45577  -16.8185        -3.89602  -10.5781  \n",
       "   2.05194  -2.72613   -8.16912    0.715291       3.48574    4.46081 \n",
       "  -6.13742  -5.74059    1.41782  -17.024         -1.11403    0.278054\n",
       "   9.04555   4.09188    3.02878    7.18583   …    9.06085   -6.22643 \n",
       "  12.2155    9.27237    4.17152   -6.02936        4.86192   -2.0607  \n",
       "  -7.35357  -2.72163   -4.7295     3.90331        3.58258    1.90419 \n",
       "  -7.79106  -3.47143   -6.35381    2.64577       -7.66081  -19.9355  \n",
       "  -6.70359  -7.68111   -6.4374   -11.2293         3.27617   -9.95783 "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class scores\n",
    "w * xtrn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×60000 Adjoint{Int64,Array{Int64,1}}:\n",
       " 2  7  7  6  2  6  4  7  6  7  7  …  4  4  6  4  7  6  7  2  7  6  4"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predictions\n",
    "[ argmax(w * xtrn[:,i]) for i in 1:NTRN ]'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×60000 Adjoint{Int64,Array{Int64,1}}:\n",
       " 5  10  4  1  9  2  1  3  1  4  3  …  8  9  2  9  5  1  8  3  5  6  8"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Correct answers\n",
    "[ argmax(ytrn[:,i]) for i in 1:NTRN ]'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0833, 0.0746)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy\n",
    "acc(w,x,y) = mean(argmax(w * x, dims=1) .== argmax(y, dims=1))\n",
    "acc(w,xtrn,ytrn), acc(w,xtst,ytst)"
   ]
  },
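  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prediction and accuracy cells above can be summarized in one formula each. The symbols $\\hat{c}_i$ (predicted class), $c_i$ (correct class) and $N$ (number of examples) are notation introduced here for illustration; they do not appear in the code.\n",
    "\n",
    "$$\\hat{c}_i = \\arg\\max_k \\, (w x_i)_k, \\qquad \\mathrm{acc}(w,x,y) = \\frac{1}{N} \\sum_{i=1}^{N} [\\hat{c}_i = c_i]$$\n",
    "\n",
    "Here $x_i$ is the $i$-th column of `x`, $c_i$ is the position of the 1 in the $i$-th column of `y`, and $[\\cdot]$ is 1 when the condition holds and 0 otherwise."
   ]
  },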
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "train (generic function with 2 methods)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training loop\n",
    "function train(algo,x,y,T=2^20)\n",
    "    w = ARRAY(zeros(size(y,1),size(x,1)))\n",
    "    nexamples = size(x,2)\n",
    "    nextprint = 1\n",
    "    for t = 1:T\n",
    "        i = rand(1:nexamples)\n",
    "        algo(w, x[:,i], y[:,i])  # <== this is where w is updated\n",
    "        if t == nextprint\n",
    "            println((iter=t, accuracy=acc(w,x,y), wnorm=norm(w)))\n",
    "            nextprint = min(2t,T)\n",
    "        end\n",
    "    end\n",
    "    w\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perceptron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "perceptron (generic function with 1 method)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
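   ],
   "source": [
    "function perceptron(w,x,y)\n",
    "    guess = argmax(w * x)   # predicted class: index of the highest score\n",
    "    class = argmax(y)       # correct class: position of the 1 in the one-hot label\n",
    "    if guess != class       # update only when the prediction is wrong\n",
    "        w[class,:] .+= x    # move the correct class's weights toward x\n",
    "        w[guess,:] .-= x    # move the predicted class's weights away from x\n",
    "    end\n",
    "end"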
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.09871666666666666, wnorm = 16.555326f0)\n",
      "(iter = 2, accuracy = 0.12575, wnorm = 19.439667f0)\n",
      "(iter = 4, accuracy = 0.13355, wnorm = 20.399765f0)\n",
      "(iter = 8, accuracy = 0.20878333333333332, wnorm = 25.85245f0)\n",
      "(iter = 16, accuracy = 0.11148333333333334, wnorm = 38.57418f0)\n",
      "(iter = 32, accuracy = 0.32356666666666667, wnorm = 48.58623f0)\n",
      "(iter = 64, accuracy = 0.2728, wnorm = 66.928535f0)\n",
      "(iter = 128, accuracy = 0.60465, wnorm = 89.28772f0)\n",
      "(iter = 256, accuracy = 0.6815666666666667, wnorm = 114.6649f0)\n",
      "(iter = 512, accuracy = 0.72925, wnorm = 141.36696f0)\n",
      "(iter = 1024, accuracy = 0.6892333333333334, wnorm = 177.80821f0)\n",
      "(iter = 2048, accuracy = 0.7991, wnorm = 229.15755f0)\n",
      "(iter = 4096, accuracy = 0.8287333333333333, wnorm = 280.51797f0)\n",
      "(iter = 8192, accuracy = 0.8581833333333333, wnorm = 337.9238f0)\n",
      "(iter = 16384, accuracy = 0.8295, wnorm = 409.2436f0)\n",
      "(iter = 32768, accuracy = 0.81145, wnorm = 485.44604f0)\n",
      "(iter = 65536, accuracy = 0.8569333333333333, wnorm = 588.75964f0)\n",
      "(iter = 131072, accuracy = 0.8549333333333333, wnorm = 707.2535f0)\n",
      "(iter = 262144, accuracy = 0.8878166666666667, wnorm = 854.45514f0)\n",
      "(iter = 524288, accuracy = 0.8979666666666667, wnorm = 1057.2964f0)\n",
      "(iter = 1048576, accuracy = 0.9036, wnorm = 1318.8921f0)\n",
      "  6.163184 seconds (7.99 M allocations: 4.509 GiB, 5.33% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.8950333333333333, wnorm = 1321.2463f0) in 7 secs\n",
    "@time wperceptron = train(perceptron,xtrn,ytrn);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adaline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "adaline (generic function with 1 method)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function adaline(w,x,y; lr=0.0001)\n",
    "    error = w * x - y        # class scores minus the one-hot target\n",
    "    w .-= lr * error * x'    # LMS update: step along the outer product error * x'\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.09751666666666667, wnorm = 0.0009055087f0)\n",
      "(iter = 2, accuracy = 0.15613333333333335, wnorm = 0.001271113f0)\n",
      "(iter = 4, accuracy = 0.175, wnorm = 0.0019275063f0)\n",
      "(iter = 8, accuracy = 0.18321666666666667, wnorm = 0.002950374f0)\n",
      "(iter = 16, accuracy = 0.15383333333333332, wnorm = 0.0043843426f0)\n",
      "(iter = 32, accuracy = 0.18926666666666667, wnorm = 0.008262319f0)\n",
      "(iter = 64, accuracy = 0.2615, wnorm = 0.014850053f0)\n",
      "(iter = 128, accuracy = 0.2799, wnorm = 0.025427587f0)\n",
      "(iter = 256, accuracy = 0.3463833333333333, wnorm = 0.04235847f0)\n",
      "(iter = 512, accuracy = 0.6255666666666667, wnorm = 0.06586557f0)\n",
      "(iter = 1024, accuracy = 0.6563333333333333, wnorm = 0.10303071f0)\n",
      "(iter = 2048, accuracy = 0.7575333333333333, wnorm = 0.1662105f0)\n",
      "(iter = 4096, accuracy = 0.7747833333333334, wnorm = 0.25526303f0)\n",
      "(iter = 8192, accuracy = 0.8031, wnorm = 0.35955703f0)\n",
      "(iter = 16384, accuracy = 0.82515, wnorm = 0.46123433f0)\n",
      "(iter = 32768, accuracy = 0.8383166666666667, wnorm = 0.5622118f0)\n",
      "(iter = 65536, accuracy = 0.8452666666666667, wnorm = 0.66491604f0)\n",
      "(iter = 131072, accuracy = 0.8482833333333333, wnorm = 0.7754936f0)\n",
      "(iter = 262144, accuracy = 0.84805, wnorm = 0.9207265f0)\n",
      "(iter = 524288, accuracy = 0.8513833333333334, wnorm = 1.089876f0)\n",
      "(iter = 1048576, accuracy = 0.8498, wnorm = 1.2931302f0)\n",
      " 26.039335 seconds (9.55 M allocations: 65.211 GiB, 4.94% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.8523, wnorm = 1.2907721f0) in 31 secs with lr=0.0001\n",
    "@time wadaline = train(adaline,xtrn,ytrn);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Softmax classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "softmax (generic function with 1 method)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
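   ],
   "source": [
    "function softmax(w,x,y; lr=0.01)\n",
    "    probs = exp.(w * x)      # unnormalized class probabilities\n",
    "    probs ./= sum(probs)     # normalize so the probabilities sum to 1\n",
    "    error = probs - y        # predicted probabilities minus the one-hot target\n",
    "    w .-= lr * error * x'    # same update form as adaline, with probabilities in place of scores\n",
    "end"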
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.10441666666666667, wnorm = 0.10300317f0)\n",
      "(iter = 2, accuracy = 0.16301666666666667, wnorm = 0.13433683f0)\n",
      "(iter = 4, accuracy = 0.28513333333333335, wnorm = 0.17768018f0)\n",
      "(iter = 8, accuracy = 0.11805, wnorm = 0.24103987f0)\n",
      "(iter = 16, accuracy = 0.20396666666666666, wnorm = 0.3208787f0)\n",
      "(iter = 32, accuracy = 0.18268333333333334, wnorm = 0.50825554f0)\n",
      "(iter = 64, accuracy = 0.32871666666666666, wnorm = 0.7586994f0)\n",
      "(iter = 128, accuracy = 0.5855, wnorm = 1.1711375f0)\n",
      "(iter = 256, accuracy = 0.68105, wnorm = 1.8444412f0)\n",
      "(iter = 512, accuracy = 0.7839333333333334, wnorm = 2.7081473f0)\n",
      "(iter = 1024, accuracy = 0.8304833333333334, wnorm = 3.6876163f0)\n",
      "(iter = 2048, accuracy = 0.8382833333333334, wnorm = 4.728037f0)\n",
      "(iter = 4096, accuracy = 0.87475, wnorm = 5.873648f0)\n",
      "(iter = 8192, accuracy = 0.88845, wnorm = 7.210838f0)\n",
      "(iter = 16384, accuracy = 0.8984833333333333, wnorm = 8.7169895f0)\n",
      "(iter = 32768, accuracy = 0.9032833333333333, wnorm = 10.475012f0)\n",
      "(iter = 65536, accuracy = 0.9134333333333333, wnorm = 12.370967f0)\n",
      "(iter = 131072, accuracy = 0.9130333333333334, wnorm = 14.670891f0)\n",
      "(iter = 262144, accuracy = 0.9238, wnorm = 17.713814f0)\n",
      "(iter = 524288, accuracy = 0.92545, wnorm = 21.554945f0)\n",
      "(iter = 1048576, accuracy = 0.9237166666666666, wnorm = 26.564526f0)\n",
      " 25.238679 seconds (9.91 M allocations: 65.305 GiB, 4.52% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.9242166666666667, wnorm = 26.523603f0) in 32 secs with lr=0.01\n",
    "@time wsoftmax = train(softmax,xtrn,ytrn);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Objectives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "optimize (generic function with 1 method)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training via optimization\n",
    "function optimize(loss,x,y; lr=0.1, iters=2^20)\n",
    "    w = Param(ARRAY(zeros(size(y,1),size(x,1))))  # Param marks w for gradient tracking\n",
    "    nexamples = size(x,2)\n",
    "    nextprint = 1\n",
    "    for t = 1:iters\n",
    "        i = rand(1:nexamples)\n",
    "        L = @diff loss(w, x[:,i], y[:,i])  # record the loss on one random example\n",
    "        ∇w = grad(L,w)                     # gradient of that loss with respect to w\n",
    "        w .-= lr * ∇w                      # stochastic gradient descent step\n",
    "        if t == nextprint\n",
    "            println((iter=t, accuracy=acc(w,x,y), wnorm=norm(w)))\n",
    "            nextprint = min(2t,iters)\n",
    "        end\n",
    "    end\n",
    "    w\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perceptron minimizes the score difference between the correct class and the prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "perceptronloss (generic function with 1 method)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function perceptronloss(w,x,y)\n",
    "    score = w * x\n",
    "    guess = argmax(score)\n",
    "    class = argmax(y)\n",
    "    score[guess] - score[class]\n",
    "end"
   ]
  },
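  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short derivation, sketched under the assumption that the predicted class is held fixed while differentiating (the loss is piecewise linear in $w$), suggests why a gradient step on `perceptronloss` with learning rate 1 behaves like the perceptron rule. The symbols $g$ (predicted class), $c$ (correct class) and $w_k$ (row $k$ of `w`) are notation introduced here.\n",
    "\n",
    "$$J(w) = w_g \\cdot x - w_c \\cdot x, \\qquad g = \\arg\\max_k \\, w_k \\cdot x$$\n",
    "\n",
    "$$\\frac{\\partial J}{\\partial w_g} = x, \\qquad \\frac{\\partial J}{\\partial w_c} = -x \\qquad (g \\neq c)$$\n",
    "\n",
    "When $g = c$ the loss is zero and so is its gradient. Otherwise a gradient step with learning rate 1 adds $x$ to row $c$ and subtracts $x$ from row $g$, which is the update in `perceptron`."
   ]
  },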
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.09863333333333334, wnorm = 15.438191f0)\n",
      "(iter = 2, accuracy = 0.09863333333333334, wnorm = 15.438191f0)\n",
      "(iter = 4, accuracy = 0.14575, wnorm = 16.696978f0)\n",
      "(iter = 8, accuracy = 0.13358333333333333, wnorm = 26.538507f0)\n",
      "(iter = 16, accuracy = 0.25551666666666667, wnorm = 40.80332f0)\n",
      "(iter = 32, accuracy = 0.38035, wnorm = 49.807693f0)\n",
      "(iter = 64, accuracy = 0.3633166666666667, wnorm = 71.64613f0)\n",
      "(iter = 128, accuracy = 0.5086333333333334, wnorm = 86.83524f0)\n",
      "(iter = 256, accuracy = 0.5217, wnorm = 112.49046f0)\n",
      "(iter = 512, accuracy = 0.7327833333333333, wnorm = 142.34885f0)\n",
      "(iter = 1024, accuracy = 0.7181333333333333, wnorm = 184.21814f0)\n",
      "(iter = 2048, accuracy = 0.7697666666666667, wnorm = 218.95062f0)\n",
      "(iter = 4096, accuracy = 0.8230666666666666, wnorm = 269.77414f0)\n",
      "(iter = 8192, accuracy = 0.8184166666666667, wnorm = 332.04178f0)\n",
      "(iter = 16384, accuracy = 0.8632666666666666, wnorm = 408.29248f0)\n",
      "(iter = 32768, accuracy = 0.8854833333333333, wnorm = 499.76358f0)\n",
      "(iter = 65536, accuracy = 0.8928833333333334, wnorm = 594.98193f0)\n",
      "(iter = 131072, accuracy = 0.8764666666666666, wnorm = 722.84406f0)\n",
      "(iter = 262144, accuracy = 0.88925, wnorm = 868.8047f0)\n",
      "(iter = 524288, accuracy = 0.8974, wnorm = 1066.8756f0)\n",
      "(iter = 1048576, accuracy = 0.8918, wnorm = 1323.9752f0)\n",
      " 50.896200 seconds (183.62 M allocations: 76.122 GiB, 6.01% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.8908833333333334, wnorm = 1322.4888f0) in 62 secs with lr=1\n",
    "@time wperceptron2 = optimize(perceptronloss,xtrn,ytrn,lr=1);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adaline minimizes the squared error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "quadraticloss (generic function with 1 method)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function quadraticloss(w,x,y)\n",
    "    0.5 * sum(abs2, w * x - y)\n",
    "end"
   ]
  },
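  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gradient of this loss can be worked out by hand. The sketch below uses $\\eta$ for the learning rate `lr`, a symbol introduced here for illustration.\n",
    "\n",
    "$$J(w) = \\tfrac{1}{2} \\lVert w x - y \\rVert^2, \\qquad \\nabla_w J = (w x - y)\\, x^\\top$$\n",
    "\n",
    "$$w \\leftarrow w - \\eta \\, (w x - y)\\, x^\\top$$\n",
    "\n",
    "This matches the `adaline` update `w .-= lr * error * x'` with `error = w * x - y`."
   ]
  },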
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.09736666666666667, wnorm = 0.0008820149f0)\n",
      "(iter = 2, accuracy = 0.16698333333333334, wnorm = 0.0012901156f0)\n",
      "(iter = 4, accuracy = 0.2466, wnorm = 0.0018553125f0)\n",
      "(iter = 8, accuracy = 0.17425, wnorm = 0.00295244f0)\n",
      "(iter = 16, accuracy = 0.23753333333333335, wnorm = 0.0044558235f0)\n",
      "(iter = 32, accuracy = 0.20073333333333335, wnorm = 0.0076521044f0)\n",
      "(iter = 64, accuracy = 0.32116666666666666, wnorm = 0.014028674f0)\n",
      "(iter = 128, accuracy = 0.4724833333333333, wnorm = 0.025002351f0)\n",
      "(iter = 256, accuracy = 0.5484333333333333, wnorm = 0.042062297f0)\n",
      "(iter = 512, accuracy = 0.6185333333333334, wnorm = 0.06702255f0)\n",
      "(iter = 1024, accuracy = 0.6920833333333334, wnorm = 0.105967544f0)\n",
      "(iter = 2048, accuracy = 0.7441166666666666, wnorm = 0.16895463f0)\n",
      "(iter = 4096, accuracy = 0.78585, wnorm = 0.25629073f0)\n",
      "(iter = 8192, accuracy = 0.8055833333333333, wnorm = 0.3571566f0)\n",
      "(iter = 16384, accuracy = 0.8308333333333333, wnorm = 0.4587624f0)\n",
      "(iter = 32768, accuracy = 0.83915, wnorm = 0.5649444f0)\n",
      "(iter = 65536, accuracy = 0.8459666666666666, wnorm = 0.6634013f0)\n",
      "(iter = 131072, accuracy = 0.84845, wnorm = 0.77901715f0)\n",
      "(iter = 262144, accuracy = 0.8511, wnorm = 0.9170074f0)\n",
      "(iter = 524288, accuracy = 0.8508, wnorm = 1.0950308f0)\n",
      "(iter = 1048576, accuracy = 0.84725, wnorm = 1.2915363f0)\n",
      " 58.777744 seconds (148.72 M allocations: 136.442 GiB, 4.69% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.8498333333333333, wnorm = 1.2882874f0) in 79 secs with lr=0.0001\n",
    "@time wadaline2 = optimize(quadraticloss,xtrn,ytrn,lr=0.0001);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Softmax classifier maximizes the probabilities of correct answers\n",
    "(or minimizes negative log likelihood, aka cross-entropy or softmax loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "negloglik (generic function with 1 method)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function negloglik(w,x,y)\n",
    "    probs = exp.(w * x)\n",
    "    probs = probs / sum(probs)\n",
    "    class = argmax(y)\n",
    "    -log(probs[class])\n",
    "end"
   ]
  },
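  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Differentiating the negative log likelihood by hand gives a gradient of the same shape as the squared-error case. In this sketch $z = w x$ are the class scores, $p$ the softmax probabilities, and $c$ the correct class; $z$, $p$ and $c$ are notation introduced here for illustration.\n",
    "\n",
    "$$p_k = \\frac{e^{z_k}}{\\sum_j e^{z_j}}, \\qquad J(w) = -\\log p_c = -z_c + \\log \\sum_j e^{z_j}$$\n",
    "\n",
    "$$\\frac{\\partial J}{\\partial z_k} = p_k - y_k, \\qquad \\nabla_w J = (p - y)\\, x^\\top$$\n",
    "\n",
    "Since $y$ is one-hot, $y_k$ is 1 for $k = c$ and 0 otherwise. A gradient step therefore subtracts `lr * (probs - y) * x'` from `w`, which is the update in `softmax`."
   ]
  },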
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.09863333333333334, wnorm = 0.09697416f0)\n",
      "(iter = 2, accuracy = 0.1686, wnorm = 0.12180235f0)\n",
      "(iter = 4, accuracy = 0.26025, wnorm = 0.15152907f0)\n",
      "(iter = 8, accuracy = 0.25661666666666666, wnorm = 0.23889048f0)\n",
      "(iter = 16, accuracy = 0.15921666666666667, wnorm = 0.37189257f0)\n",
      "(iter = 32, accuracy = 0.2731166666666667, wnorm = 0.5208997f0)\n",
      "(iter = 64, accuracy = 0.5370666666666667, wnorm = 0.77327734f0)\n",
      "(iter = 128, accuracy = 0.6934833333333333, wnorm = 1.1921613f0)\n",
      "(iter = 256, accuracy = 0.6787666666666666, wnorm = 1.8622694f0)\n",
      "(iter = 512, accuracy = 0.7942333333333333, wnorm = 2.7099152f0)\n",
      "(iter = 1024, accuracy = 0.79965, wnorm = 3.6680388f0)\n",
      "(iter = 2048, accuracy = 0.8591666666666666, wnorm = 4.7135873f0)\n",
      "(iter = 4096, accuracy = 0.8641166666666666, wnorm = 5.8873725f0)\n",
      "(iter = 8192, accuracy = 0.8803833333333333, wnorm = 7.165706f0)\n",
      "(iter = 16384, accuracy = 0.89935, wnorm = 8.757448f0)\n",
      "(iter = 32768, accuracy = 0.9097166666666666, wnorm = 10.348306f0)\n",
      "(iter = 65536, accuracy = 0.9008166666666667, wnorm = 12.4738035f0)\n",
      "(iter = 131072, accuracy = 0.9122166666666667, wnorm = 14.797868f0)\n",
      "(iter = 262144, accuracy = 0.9197, wnorm = 17.693712f0)\n",
      "(iter = 524288, accuracy = 0.9240833333333334, wnorm = 21.383766f0)\n",
      "(iter = 1048576, accuracy = 0.9300166666666667, wnorm = 26.568132f0)\n",
      " 78.487115 seconds (279.97 M allocations: 114.282 GiB, 4.81% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.9283833333333333, wnorm = 26.593485f0) in 120 secs with lr=0.01\n",
    "@time wsoftmax2 = optimize(negloglik,xtrn,ytrn,lr=0.01);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.3.1",
   "language": "julia",
   "name": "julia-1.3"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.3.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
